In [8]:
# !pip install kaleido numpy pandas neurokit2 plotly seaborn ts2vg pyxdf
Imports and helper functions¶
In [9]:
import json
import numpy as np
import pandas as pd
import neurokit2 as nk
import plotly.io as pio
import plotly.express as px
from _plotly_utils.colors import n_colors
from matplotlib import pyplot as plt
import plotly.graph_objs as go
import warnings
import glob
from operator import itemgetter
import os
from util import plot_data, plot_channels, plot_gantt, markers_to_gantt, plot_epoch, plot_correlation_matrix, \
seconds_to_samples, samples_to_seconds, g
from IPython.display import display
pio.renderers.default = 'notebook_connected+jupyterlab'
In [10]:
# Adapted from Neurokit's read_xdf() to work with the data from this experiment
# https://neuropsychology.github.io/NeuroKit/functions/data.html#neurokit2.data.read_xdf
def read_xdf(subject, upsample=2, fillmissing=None):
    """**Read and tidy an XDF file**

    Loads the subject's XDF recording, extracts the experiment markers,
    restricts all data streams to the window spanned by the markers, merges
    the streams on a common datetime index, and resamples/interpolates them
    onto one evenly spaced target sampling rate.

    Parameters
    ----------
    subject : str
        Subject identifier; passed to ``g()`` to locate the subject's files.
    upsample : float, optional
        The target sampling rate is ``max(nominal stream rates) * upsample``.
    fillmissing : float or None, optional
        Maximum gap (in seconds) to fill by interpolation; ``None`` means
        interpolate without limit.

    Returns
    -------
    streams_df : pd.DataFrame
        Merged, resampled and interpolated signal data (datetime index
        starting at the analysis-window start).
    markers_df : pd.DataFrame
        One row per retained marker, datetime-indexed relative to the same
        offset as ``streams_df``.
    info : dict
        Sampling-rate metadata, recording datetime, the per-stream
        dataframes and the subject id.
    """
    def get_markers(markers_stream):
        # Extract the marker strings; every event must hold exactly one.
        markers = markers_stream['time_series']
        assert all(len(marker) == 1 for marker in markers), 'Warning: There is an event containing more than one marker'
        markers = [marker[0] for marker in markers]
        designs = ["fair", "dark"]
        # Counterbalancing: the subject's meta file records which design
        # came first. (Context manager fixes an unclosed file handle.)
        with open(g(subject, '*meta.json')[0], 'r') as meta_file:
            dark_first = json.load(meta_file)['darkFirst']
        if dark_first:
            designs = designs[::-1]
        events = [
            "app/start", "cookies/start", "cookies/end",
            "geolocation/start", "geolocation/end",
            "notification/start", "notification/end",
            "travelProtection/start", "travelProtection/end",
            "newsletter/start", "newsletter/end", "app/end"
        ]
        expected_marker_order = [f"{d}/{e}" for d in designs for e in events]
        # Only keep unique markers occurring in the expected order; missing
        # markers are skipped silently (original behavior). The single
        # `next()` scan replaces the redundant any()+next() double scan.
        relevant_indices = []
        previous_index = 0
        for marker in expected_marker_order:
            # First occurrence at/after the previous match; substring match
            # because raw marker strings embed "<design>/<event>" in a
            # longer URL-like path.
            index = next((i for i in range(previous_index, len(markers)) if marker in markers[i]), None)
            if index is not None:
                relevant_indices.append(index)
                previous_index = index
        # Fancy indexing keeps an ndarray even when only one marker matched
        # (itemgetter would return a bare scalar in that case, and its tuple
        # result only supported the later `timestamps - offset` by accident
        # of numpy-scalar broadcasting).
        timestamps = np.asarray(markers_stream['time_stamps'])[relevant_indices]
        markers = [markers[i] for i in relevant_indices]
        return markers, timestamps
    try:
        import pyxdf
    except ImportError as exc:
        # One concatenated message; the original passed two comma-separated
        # strings, which made ImportError print an args tuple.
        raise ImportError(
            "The 'pyxdf' module is required for this function to run. "
            "Please install it first (`pip install pyxdf`)."
        ) from exc
    # Load file
    print(f"Reading xdf file for subject: {subject}")
    streams, header = pyxdf.load_xdf(g(subject, '*.xdf')[0])
    # Remove any empty streams
    streams = [stream for stream in streams if len(stream['time_series'])]
    # The markers stream is the one whose time_series is a plain list
    # (signal streams carry arrays); process it first and drop it.
    markers_stream = next(filter(lambda stream: isinstance(stream['time_series'], list), streams))
    streams.remove(markers_stream)
    markers, timestamps = get_markers(markers_stream)
    # Get the time range for analysis (from first to last marker)
    min_marker_time = min(timestamps)
    max_marker_time = max(timestamps)
    # Find the actual data range across all streams
    all_stream_times = []
    for stream in streams:
        all_stream_times.extend(stream["time_stamps"])
    data_start_time = min(all_stream_times)
    data_end_time = max(all_stream_times)
    # Use the overlap between data and markers as the analysis window
    analysis_start = max(data_start_time, min_marker_time)
    analysis_end = min(data_end_time, max_marker_time)
    # All timestamps below are made relative to the analysis start.
    offset = analysis_start
    markers_df = pd.DataFrame(markers, columns=['marker'])
    # Keep only the "<design>/<event>/<phase>" part of the marker path
    # (path segments 3..5 of the raw, URL-like marker string).
    markers_df["marker"] = markers_df["marker"].str.split("/").str[3:6].apply("/".join)
    markers_df.index = pd.to_datetime(timestamps - offset, unit="s")
    # Process other streams and convert to dataframes clipped to the window.
    dfs = []
    for stream in streams:
        time_mask = (stream["time_stamps"] >= analysis_start) & (stream["time_stamps"] <= analysis_end)
        if not np.any(time_mask):
            print("Warning: No data in analysis window for stream")
            continue
        filtered_timestamps = stream["time_stamps"][time_mask]
        filtered_data = stream["time_series"][time_mask]
        print(f"Stream data after filtering: {len(filtered_timestamps)} samples")
        # Channel labels come from the stream's metadata description.
        channels_info = stream["info"]["desc"][0]["channels"][0]["channel"]
        cols = [channel["label"][0] for channel in channels_info]
        dat = pd.DataFrame(filtered_data, columns=cols)
        # Apply offset to timestamps
        dat.index = pd.to_datetime(filtered_timestamps - offset, unit="s")
        dfs.append(dat)
    if not dfs:
        raise ValueError("No valid data found in analysis window")
    # Store info of each stream
    info = {
        "sampling_rates_original": [float(s["info"]["nominal_srate"][0]) for s in streams],
        "sampling_rates_effective": [float(s["info"]["effective_srate"]) for s in streams],
        "datetime": header["info"]["datetime"][0],
        "data": dfs,
        "subject": subject,
    }
    # Merge all dataframes by timestamps (outer join keeps every sample).
    streams_df = dfs[0]
    for other in dfs[1:]:
        streams_df = pd.merge(streams_df, other, how="outer", left_index=True, right_index=True)
    streams_df = streams_df.sort_index()
    # Resample and Interpolate
    info["sampling_rate"] = int(np.max(info["sampling_rates_original"]) * upsample)
    print(f"Target sampling rate: {info['sampling_rate']} Hz")
    if fillmissing is not None:
        # Convert the gap limit from seconds to samples.
        fillmissing = int(info["sampling_rate"] * fillmissing)
    # Create new index with evenly spaced timestamps
    idx = pd.date_range(
        streams_df.index.min(),
        streams_df.index.max(),
        freq=str(1000 / info["sampling_rate"]) + "ms"
    )
    # Union first so time-based interpolation can use the original sample
    # times, then keep only the evenly spaced grid.
    streams_df = streams_df.reindex(streams_df.index.union(idx))
    # Only interpolate numeric columns
    numeric_cols = streams_df.select_dtypes(include=[np.number]).columns
    streams_df[numeric_cols] = streams_df[numeric_cols].interpolate(method="time", limit=fillmissing)
    # Use the new evenly spaced index
    streams_df = streams_df.reindex(idx)
    return streams_df, markers_df, info
In [11]:
def get_signals_and_events(streams_df, markers_df, info, plot=True):
    """Clean up the raw streams, process the EDA signal, and build the
    events dictionary used for event-related (epoch) analysis.

    Parameters
    ----------
    streams_df : pd.DataFrame
        Merged raw streams from ``read_xdf`` (datetime index).
    markers_df : pd.DataFrame
        Markers from ``read_xdf`` (datetime index).
    info : dict
        Stream info from ``read_xdf``; only ``info["sampling_rate"]`` is used.
    plot : bool, optional
        Whether to plot the processed signals and a gantt chart of the markers.

    Returns
    -------
    signals : pd.DataFrame
        Processed EDA signals (tonic/phasic components etc.).
    events : dict
        Onsets, durations, labels, conditions, designs and event names in
        the format expected by ``nk.epochs_create``.
    """
    # Clean up the streams: rename the raw EDA channel for clarity and keep
    # only the channels we actually use. (A dead `if False:` plotting block
    # was removed here.)
    rename_columns = {
        'RAW0': 'EDA',
    }
    columns_to_keep = list(rename_columns.values())
    streams_df = streams_df.rename(columns=rename_columns)
    streams_df = streams_df[columns_to_keep]
    # Process EDA (biosppy cleaning, NeuroKit's default decomposition).
    eda_signals, eda_info = nk.eda_process(streams_df['EDA'], sampling_rate=info["sampling_rate"],
                                           method_cleaning="biosppy")
    if plot:
        nk.eda_plot(eda_signals, eda_info)
    # Concatenate processed signals (single modality for now; extend here).
    signals = pd.concat([eda_signals], axis=1)
    # Reindex markers with integer sample numbers for further processing.
    nearest_indices = streams_df.index.get_indexer(markers_df.index, method='nearest')
    markers_numindexed = markers_df.copy()
    markers_numindexed.index = nearest_indices  # Use integer sample numbers
    if plot:
        # Plot processed signals (only some columns)
        columns_to_plot = ['EDA_Tonic', 'EDA_Phasic']
        plot_channels(signals[columns_to_plot], markers_numindexed, title='Processed Channel Data',
                      hide_end_markers=True)
        plot_gantt(markers_df)
    # Remove the app start/end markers; they delimit the whole session
    # rather than individual events.
    markers_to_remove = '|'.join(['/app'])
    markers_numindexed = markers_numindexed[~markers_numindexed.marker.str.contains(markers_to_remove)]
    # Create events dictionary from gantt data for event-related analysis.
    gantt_data = markers_to_gantt(markers_numindexed)
    labels = [d['marker'] for d in gantt_data]
    # Each marker label is "<design>/<event>/..."; split out the parts.
    designs, event_names = zip(*[d['marker'].split('/')[:2] for d in gantt_data])
    events = dict(
        onset=[d['start'] for d in gantt_data],
        duration=[d['duration'] for d in gantt_data],
        label=labels,
        condition=labels,
        designs=designs,
        event_names=event_names
    )
    return signals, events
In [12]:
def get_epoch_features(signals, events, info, plot=True):
    """Create epochs around each event and extract per-epoch features, plus
    interval features for the signal outside all epochs.

    Parameters
    ----------
    signals : pd.DataFrame
        Processed signals from ``get_signals_and_events``.
    events : dict
        Events dictionary from ``get_signals_and_events``.
    info : dict
        Stream info; only ``info["sampling_rate"]`` is used.
    plot : bool, optional
        Whether to plot each epoch.

    Returns
    -------
    epochs : dict
        Epoch dataframes keyed by event label.
    epoch_features : pd.DataFrame
        One row of features per epoch.
    non_epoch_features : pd.DataFrame
        Interval-analysis features for the non-epoch signal portion.
    """
    # Build epochs spanning -1s to +5s around each event onset.
    epochs = nk.epochs_create(signals, events, sampling_rate=info["sampling_rate"], epochs_start=-1, epochs_end=5)
    # Collect all sample indices covered by any epoch, then drop them to
    # obtain the non-epoch portion of the signal.
    epoch_onsets = events['onset']
    epoch_lengths = [len(e) for e in epochs.values()]
    epoch_indices = []
    for start, length in zip(epoch_onsets, epoch_lengths):
        epoch_indices.extend(range(start, start + length))
    non_epoch_signals = signals.drop(index=epoch_indices)
    if plot:
        for epoch in epochs.values():
            plot_epoch(epoch, subplots=True)
            fig = plt.gcf()
            fig.suptitle(f"Epoch from -1 to 5 seconds for Event: {epoch['Condition'].values[0]}")
            fig.show()
    # Event-related analysis: one feature row per epoch.
    bio_epoch_features = nk.bio_analyze(epochs, sampling_rate=info["sampling_rate"])
    epoch_features = pd.concat([bio_epoch_features], axis=1)
    # Calculate some additional hand-made features per epoch.
    for event in epochs.keys():
        epoch_features.loc[event, 'total_duration'] = samples_to_seconds(
            events['duration'][events['label'].index(event)], info["sampling_rate"])
        epoch_features.loc[event, 'EDA_Phasic_Mean'] = epochs[event]['EDA_Phasic'].mean()
        # Series.max() skips NaNs, unlike the NaN-unsafe builtin max().
        epoch_features.loc[event, 'EDA_Tonic_Max'] = epochs[event]['EDA_Tonic'].max()
    epoch_features['event_name'] = list(events['event_names'])
    epoch_features['design'] = list(events['designs'])
    epoch_features = epoch_features.drop(columns=["Event_Onset", "Label", "Condition"])
    # Interval analysis for everything outside the epochs.
    non_epoch_features = nk.bio_analyze(non_epoch_signals, sampling_rate=info["sampling_rate"], method="interval")
    return epochs, epoch_features, non_epoch_features
Load and process all data¶
In [13]:
# Per-subject processed outputs are cached as CSVs; only reprocess the raw
# XDF recordings when one of the caches is missing.
required_files = ['non_epoch_features.csv', 'epoch_features.csv', 'epoch_signals.csv'
                  # , 'metadata.csv'
                  ]
if not all(os.path.isfile(fname) for fname in required_files):
    # Processed data hasn't been saved yet, so process raw data
    non_epoch_features_dfs = []
    dfs = []  # Epoch features, one dataframe per subject
    epoch_signal_dfs = []
    meta_dfs = []
    # NOTE(review): the [1:] slice drops the first subject directory —
    # presumably a pilot/excluded subject; confirm this is intentional.
    for i, subject in enumerate(sorted([d for d in os.listdir('data') if not d.startswith('.')])[1:], start=1):
        # Read and clean subject's XDF file and extract streams and markers as dataframes
        streams_df, markers_df, info = read_xdf(subject, upsample=2)
        # Further process the streams and markers
        signals, events = get_signals_and_events(streams_df, markers_df, info, plot=True)
        # Do event-related analysis to get features about epochs surrounding the events
        epochs, epoch_features, non_epoch_features = get_epoch_features(signals, events, info, plot=False)
        non_epoch_features.index = [i]
        # Stack all epoch signals for this subject, tagging each sample row
        # with the subject number, event name and design.
        epoch_signal_df = pd.concat(epochs.values())
        epoch_signal_df.insert(0, 'subject', i)
        # All epochs have the same length (fixed -1s..+5s window).
        epoch_length = len(epochs[next(iter(epochs))])
        epoch_signal_df['event_name'] = [event_name for event_name in events['event_names']
                                         for _ in range(epoch_length)]
        epoch_signal_df['design'] = [design for design in events['designs']
                                     for _ in range(epoch_length)]
        # Load subject's questionnaire answers
        # part1_answers, iuipc_score = load_questionnaire(subject)
        df = (epoch_features
              # .join(part1_answers)
              )
        df.insert(0, 'subject', i)
        epoch_signal_dfs.append(epoch_signal_df)
        # Load subject's demographics
        # meta_df = load_demographics(subject)
        # meta_df.index = [i]
        # meta_df['iuipc'] = iuipc_score
        non_epoch_features_dfs.append(non_epoch_features)
        # Bug fix: this append was commented out, so epoch_features.csv
        # below contained only the LAST subject's features.
        dfs.append(df)
        # meta_dfs.append(meta_df)
    non_epoch_df = pd.concat(non_epoch_features_dfs)
    non_epoch_df.to_csv('non_epoch_features.csv')
    # Bug fix: concatenate every subject's epoch features before saving.
    df = pd.concat(dfs)
    df.to_csv('epoch_features.csv')
    epoch_signal_df = pd.concat(epoch_signal_dfs)
    epoch_signal_df.to_csv('epoch_signals.csv')
    #
    # meta_df = pd.concat(meta_dfs)
    # meta_df.to_csv('metadata.csv')
else:
    # Processed data has already been saved, so load it
    df = pd.read_csv('epoch_features.csv', index_col=0)
    # non_epoch_df = pd.read_csv('non_epoch_features.csv', index_col=0)
    epoch_signal_df = pd.read_csv('epoch_signals.csv', index_col=0)
    # meta_df = pd.read_csv('metadata.csv', index_col=0)
Reading xdf file for subject: 002 Stream data after filtering: 377640 samples Target sampling rate: 2000 Hz